In [None]:
# %pip install -q diffusers transformers xformers accelerate ray safetensors
# %pip install -q numpy scipy ftfy imageio matplotlib Pillow gradio
# %pip install -q python-dotenv

In [None]:
from dotenv import load_dotenv
load_dotenv()

import os
# your models cache will be stored here
# os.environ['HUGGINGFACE_HUB_CACHE'] = 'D:\\Code\\Huggingface_cache\\'

In [None]:
import torch
import numpy as np
import ray

import datetime
import json

from PIL import Image
from tqdm.auto import tqdm

from diffusers import (
    StableDiffusionPipeline, 
    StableDiffusionInpaintPipeline, 
    StableDiffusionImg2ImgPipeline, 
    CycleDiffusionPipeline, 
    StableDiffusionDepth2ImgPipeline
)

from diffusers import (
    DDIMScheduler, 
    PNDMScheduler, 
    LMSDiscreteScheduler, 
    DPMSolverMultistepScheduler, 
    EulerAncestralDiscreteScheduler, 
    EulerDiscreteScheduler
)


from accelerate.hooks import remove_hook_from_submodules

from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
from transformers import logging
logging.set_verbosity_error()

# ignore_reinit_error lets this cell be re-run on a live kernel without
# failing because Ray has already been initialised.
ray.init(ignore_reinit_error=True)

In [None]:
# Sanity check: prints whether PyTorch reports an available CUDA device.
print(torch.cuda.is_available())

In [None]:
# Optimization flags, see
# https://huggingface.co/docs/diffusers/optimization/fp16#memory-efficient-attention
# Turn on the cuDNN benchmark flag.
torch.backends.cudnn.benchmark = True
# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# get token: https://huggingface.co/settings/tokens
# Prefer the HF_TOKEN environment variable (e.g. set in the .env file loaded by
# load_dotenv() above) so the secret never has to be written into the notebook.
# The placeholder is kept as a fallback for backward compatibility.
HF_TOKEN = os.environ.get("HF_TOKEN", "YOUR_TOKEN")

In [None]:
# Default checkpoint id. Alternatives previously listed next to it:
#   "ckpt/anything-v4.5", "andite/anything-v4.0", "SG161222/Realistic_Vision_V1.4",
#   "Linaqruf/anything-v3.0", "admruul/anything-v3.0", "stabilityai/stable-diffusion-2-1",
#   "CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion", "runwayml/stable-diffusion-v1-5"
model_id = "runwayml/stable-diffusion-v1-5"
# Checkpoint ids, presumably for the inpainting and depth-to-image interfaces
# (their use is not visible in this part of the notebook).
inpaint_model_id = "runwayml/stable-diffusion-inpainting"
depth2img_model_id = "stabilityai/stable-diffusion-2-depth"
# Example of pointing model_id at a local folder instead:
# model_id = "D:\\Code\\Huggingface_cache\\800\\"

In [None]:
@ray.remote(num_gpus=1)
class StableDiffusionInterface():
    """Ray actor (declared with ``num_gpus=1``) wrapping ``StableDiffusionPipeline``.

    Loads the pipeline once in ``__init__`` and generates images from a text
    prompt through ``__call__``.
    """

    def __init__(self, model_name, scheduler = None, use_auth_token = None, torch_dtype = torch.float32, safe_mode = False, device = "cuda", revision = "fp16",):
        """Load the pipeline and switch on the memory optimizations.

        Args:
            model_name: Model id or path handed to ``from_pretrained``.
            scheduler: Optional scheduler instance. When ``None`` no scheduler
                override is passed to ``from_pretrained``; one can still be
                installed later with ``set_scheduler``.
            use_auth_token: Forwarded to ``from_pretrained``.
            torch_dtype: Forwarded to ``from_pretrained``.
            safe_mode, device, revision: Accepted for signature compatibility;
                not used by this implementation (the safety checker is always
                removed and the generator is always created on CUDA).
        """
        # Only override the scheduler when one was supplied, so callers that
        # omit it no longer fail with a missing-argument TypeError.
        scheduler_kwargs = {} if scheduler is None else {"scheduler": scheduler}
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype = torch_dtype,
            use_auth_token = use_auth_token,
            **scheduler_kwargs,
        )
        self.pipe.safety_checker = None

        # optimization
        self.pipe.enable_model_cpu_offload()
        self.pipe.enable_attention_slicing(1)
        self.pipe.unet.to(memory_format=torch.channels_last)
        self.pipe.enable_vae_slicing()
        self.pipe.enable_vae_tiling()
        self.pipe.enable_xformers_memory_efficient_attention()

        # Kept under a private name so the attribute does not shadow model_name().
        self._model_name = model_name

    def remove_hooks(self):
        """Disable xformers attention and attention slicing, then strip the
        accelerate hooks from the VAE, text encoder and UNet."""
        self.pipe.disable_xformers_memory_efficient_attention()
        self.pipe.disable_attention_slicing()
        for module in (self.pipe.vae, self.pipe.text_encoder, self.pipe.unet):
            remove_hook_from_submodules(module)

    def set_scheduler(self, scheduler):
        """Replace the pipeline's scheduler."""
        self.pipe.scheduler = scheduler

    def name(self):
        """Return this interface's class name."""
        return self.__class__.__name__

    def model_name(self):
        """Return the model id/path this interface was created with."""
        return self._model_name

    def __call__(
        self,
        prompt = "",
        height = 64,
        width = 64,
        negative_prompt = "",
        num_images_per_prompt = 1,
        num_inference_steps = 50,
        guidance_scale = 7.5,
        seed = None
    ):
        """Run the pipeline and return its ``.images``.

        A CUDA generator seeded with ``seed`` is used when ``seed`` is not
        ``None``; otherwise no generator is passed.
        """
        g_cuda = None
        if seed is not None:
            g_cuda = torch.Generator(device='cuda')
            g_cuda.manual_seed(seed)

        return self.pipe(
            prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator = g_cuda
        ).images

In [None]:
@ray.remote(num_gpus=1)
class CycleDiffusionInterface:
    """Ray actor (declared with ``num_gpus=1``) wrapping ``CycleDiffusionPipeline``.

    Loads the pipeline once in ``__init__`` and edits an input image guided by
    a source prompt and a target prompt through ``__call__``.
    """

    def __init__(self, model_name, scheduler = None, use_auth_token = None, torch_dtype = torch.float32, safe_mode = False, device = "cuda", revision = "fp16",):
        """Load the pipeline and switch on the memory optimizations.

        Args:
            model_name: Model id or path handed to ``from_pretrained``.
            scheduler: Optional scheduler instance. When ``None`` no scheduler
                override is passed to ``from_pretrained``; one can still be
                installed later with ``set_scheduler``.
            use_auth_token: Forwarded to ``from_pretrained``.
            torch_dtype: Forwarded to ``from_pretrained``.
            safe_mode, device, revision: Accepted for signature compatibility;
                not used by this implementation.
        """
        # Only override the scheduler when one was supplied, so callers that
        # omit it no longer fail with a missing-argument TypeError.
        scheduler_kwargs = {} if scheduler is None else {"scheduler": scheduler}
        self.pipe = CycleDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            use_auth_token = use_auth_token,
            # revision=revision,
            **scheduler_kwargs,
        )
        self.pipe.safety_checker = None

        self.pipe.enable_model_cpu_offload()
        self.pipe.enable_attention_slicing(1)
        self.pipe.unet.to(memory_format=torch.channels_last)
        # self.pipe.enable_vae_slicing() #Not working with cycle stable diffusion pipeline
        # self.pipe.enable_vae_tiling() #Not working with cycle stable diffusion pipeline
        self.pipe.enable_xformers_memory_efficient_attention()

        # Kept under a private name so the attribute does not shadow model_name().
        self._model_name = model_name

    def remove_hooks(self):
        """Disable xformers attention and attention slicing, then strip the
        accelerate hooks from the VAE, text encoder and UNet."""
        self.pipe.disable_xformers_memory_efficient_attention()
        self.pipe.disable_attention_slicing()
        for module in (self.pipe.vae, self.pipe.text_encoder, self.pipe.unet):
            remove_hook_from_submodules(module)

    def set_scheduler(self, scheduler):
        """Replace the pipeline's scheduler."""
        self.pipe.scheduler = scheduler

    def name(self):
        """Return this interface's class name."""
        return self.__class__.__name__

    def model_name(self):
        """Return the model id/path this interface was created with."""
        return self._model_name

    def __call__(
        self,
        prompt = "",
        source_prompt = "",
        image = None,
        height = 64,
        width = 64,
        num_images_per_prompt = 1,
        num_inference_steps = 50,
        eta=0.1,
        strength=0.85,
        guidance_scale = 7.5,
        source_guidance_scale=1,
        seed = None
    ):
        """Run the pipeline on ``image`` resized to ``(width, height)`` and
        return its ``.images``.

        A CUDA generator seeded with ``seed`` is used when ``seed`` is not
        ``None``; otherwise no generator is passed.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("image is required")

        g_cuda = None
        if seed is not None:
            g_cuda = torch.Generator(device='cuda')
            g_cuda.manual_seed(seed)

        return self.pipe(
            prompt,
            source_prompt = source_prompt,
            image = image.resize((width, height)),
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            eta=eta,
            strength=strength,
            guidance_scale=guidance_scale,
            source_guidance_scale=source_guidance_scale,
            generator = g_cuda
        ).images

In [None]:
@ray.remote(num_gpus=1)
class Img2ImgInterface:
    """Ray actor (declared with ``num_gpus=1``) wrapping ``StableDiffusionImg2ImgPipeline``.

    Loads the pipeline once in ``__init__`` and transforms an input image
    guided by a text prompt through ``__call__``.
    """

    def __init__(self, model_name, scheduler = None, use_auth_token = None, torch_dtype = torch.float32, safe_mode = False, device = "cuda", revision = "fp16",):
        """Load the pipeline and switch on the memory optimizations.

        Args:
            model_name: Model id or path handed to ``from_pretrained``.
            scheduler: Optional scheduler instance. When ``None`` no scheduler
                override is passed to ``from_pretrained``; one can still be
                installed later with ``set_scheduler``.
            use_auth_token: Forwarded to ``from_pretrained``.
            torch_dtype: Forwarded to ``from_pretrained``.
            safe_mode, device, revision: Accepted for signature compatibility;
                not used by this implementation.
        """
        # Only override the scheduler when one was supplied, so callers that
        # omit it no longer fail with a missing-argument TypeError.
        scheduler_kwargs = {} if scheduler is None else {"scheduler": scheduler}
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            use_auth_token = use_auth_token,
            # revision=revision,
            **scheduler_kwargs,
        )
        self.pipe.safety_checker = None

        self.pipe.enable_model_cpu_offload()
        self.pipe.enable_attention_slicing(1)
        self.pipe.unet.to(memory_format=torch.channels_last)
        # self.pipe.enable_vae_slicing() #Not working with img2img stable diffusion pipeline
        # self.pipe.enable_vae_tiling() #Not working with img2img stable diffusion pipeline
        self.pipe.enable_xformers_memory_efficient_attention()

        # Kept under a private name so the attribute does not shadow model_name().
        self._model_name = model_name

    def remove_hooks(self):
        """Disable xformers attention and attention slicing, then strip the
        accelerate hooks from the VAE, text encoder and UNet."""
        self.pipe.disable_xformers_memory_efficient_attention()
        self.pipe.disable_attention_slicing()
        for module in (self.pipe.vae, self.pipe.text_encoder, self.pipe.unet):
            remove_hook_from_submodules(module)

    def set_scheduler(self, scheduler):
        """Replace the pipeline's scheduler."""
        self.pipe.scheduler = scheduler

    def name(self):
        """Return this interface's class name."""
        return self.__class__.__name__

    def model_name(self):
        """Return the model id/path this interface was created with."""
        return self._model_name

    def __call__(
        self,
        prompt = "",
        negative_prompt = "",
        image = None,
        height = 64,
        width = 64,
        num_images_per_prompt = 1,
        num_inference_steps = 50,
        eta=0.1,
        strength=0.85,
        guidance_scale = 7.5,
        seed = None
    ):
        """Run the pipeline on ``image`` resized to ``(width, height)`` and
        return its ``.images``.

        A CUDA generator seeded with ``seed`` is used when ``seed`` is not
        ``None``; otherwise no generator is passed.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("image is required")

        g_cuda = None
        if seed is not None:
            g_cuda = torch.Generator(device='cuda')
            g_cuda.manual_seed(seed)

        return self.pipe(
            prompt,
            negative_prompt = negative_prompt,
            image = image.resize((width, height)),
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            eta=eta,
            strength=strength,
            guidance_scale=guidance_scale,
            generator = g_cuda
        ).images

In [None]:
@ray.remote(num_gpus=1)
class InpaintInterface:
    """Ray actor (declared with ``num_gpus=1``) wrapping ``StableDiffusionInpaintPipeline``.

    Loads the pipeline once in ``__init__`` and repaints the masked part of an
    input image guided by a text prompt through ``__call__``.
    """

    def __init__(self, model_name, scheduler = None, use_auth_token = None, torch_dtype = torch.float32, safe_mode = False, device = "cuda", revision = "fp16"):
        """Load the pipeline and switch on the memory optimizations.

        Args:
            model_name: Model id or path handed to ``from_pretrained``.
            scheduler: Optional scheduler instance. When ``None`` no scheduler
                override is passed to ``from_pretrained``; one can still be
                installed later with ``set_scheduler``.
            use_auth_token: Forwarded to ``from_pretrained``.
            torch_dtype: Forwarded to ``from_pretrained``.
            safe_mode, device, revision: Accepted for signature compatibility;
                not used by this implementation.
        """
        # Only override the scheduler when one was supplied, so callers that
        # omit it no longer fail with a missing-argument TypeError.
        scheduler_kwargs = {} if scheduler is None else {"scheduler": scheduler}
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            use_auth_token = use_auth_token,
            **scheduler_kwargs,
        )
        self.pipe.safety_checker = None

        self.pipe.enable_model_cpu_offload()
        self.pipe.enable_attention_slicing(1)
        self.pipe.unet.to(memory_format=torch.channels_last)
        # self.pipe.enable_vae_slicing() #Not working with inpaint stable diffusion pipeline
        # self.pipe.enable_vae_tiling() #Not working with inpaint stable diffusion pipeline
        self.pipe.enable_xformers_memory_efficient_attention()

        # Kept under a private name so the attribute does not shadow model_name().
        self._model_name = model_name

    def remove_hooks(self):
        """Disable xformers attention and attention slicing, then strip the
        accelerate hooks from the VAE, text encoder and UNet."""
        self.pipe.disable_xformers_memory_efficient_attention()
        self.pipe.disable_attention_slicing()
        for module in (self.pipe.vae, self.pipe.text_encoder, self.pipe.unet):
            remove_hook_from_submodules(module)

    def set_scheduler(self, scheduler):
        """Replace the pipeline's scheduler."""
        self.pipe.scheduler = scheduler

    def name(self):
        """Return this interface's class name."""
        return self.__class__.__name__

    def model_name(self):
        """Return the model id/path this interface was created with."""
        return self._model_name

    def __call__(
        self,
        prompt = "",
        negative_prompt = "",
        image = None,
        mask_image = None,
        height = 64,
        width = 64,
        num_images_per_prompt = 1,
        num_inference_steps = 50,
        eta=0.1,
        guidance_scale = 7.5,
        seed = None
    ):
        """Run the pipeline on ``image`` and ``mask_image`` (both resized to
        ``(width, height)``) and return its ``.images``.

        A CUDA generator seeded with ``seed`` is used when ``seed`` is not
        ``None``; otherwise no generator is passed.

        Raises:
            ValueError: If ``image`` or ``mask_image`` is ``None``.
        """
        if image is None or mask_image is None:
            raise ValueError("image and mask_image are required")

        g_cuda = None
        if seed is not None:
            g_cuda = torch.Generator(device='cuda')
            g_cuda.manual_seed(seed)

        return self.pipe(
            prompt,
            negative_prompt = negative_prompt,
            image = image.resize((width, height)),
            mask_image = mask_image.resize((width, height)),
            height = height,
            width = width,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            eta=eta,
            guidance_scale=guidance_scale,
            generator = g_cuda
        ).images

In [None]:
@ray.remote(num_gpus=1)
class Depth2ImgInterface:
    """Ray actor (declared with ``num_gpus=1``) wrapping ``StableDiffusionDepth2ImgPipeline``.

    Loads the pipeline once in ``__init__`` and transforms an input image
    guided by a text prompt through ``__call__``.
    """

    def __init__(self, model_name, scheduler = None, use_auth_token = None, torch_dtype = torch.float32, safe_mode = False, device = "cuda", revision = "fp16"):
        """Load the pipeline and switch on the memory optimizations.

        Args:
            model_name: Model id or path handed to ``from_pretrained``.
            scheduler: Optional scheduler instance. When ``None`` no scheduler
                override is passed to ``from_pretrained``; one can still be
                installed later with ``set_scheduler``.
            use_auth_token: Forwarded to ``from_pretrained``.
            torch_dtype: Forwarded to ``from_pretrained``.
            safe_mode, device, revision: Accepted for signature compatibility;
                not used by this implementation.
        """
        # Only override the scheduler when one was supplied, so callers that
        # omit it no longer fail with a missing-argument TypeError.
        scheduler_kwargs = {} if scheduler is None else {"scheduler": scheduler}
        self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            use_auth_token = use_auth_token,
            **scheduler_kwargs,
        )
        self.pipe.safety_checker = None

        # Unlike the other interfaces, this one uses sequential CPU offload
        # instead of model CPU offload.
        # self.pipe.enable_model_cpu_offload()
        self.pipe.enable_sequential_cpu_offload()
        self.pipe.enable_attention_slicing(1)
        self.pipe.unet.to(memory_format=torch.channels_last)
        # self.pipe.enable_vae_slicing() #Not working with depth2img stable diffusion pipeline
        # self.pipe.enable_vae_tiling() #Not working with depth2img stable diffusion pipeline
        self.pipe.enable_xformers_memory_efficient_attention()

        # Kept under a private name so the attribute does not shadow model_name().
        self._model_name = model_name

    def remove_hooks(self):
        """Disable xformers attention and attention slicing, then strip the
        accelerate hooks from the VAE, text encoder and UNet."""
        self.pipe.disable_xformers_memory_efficient_attention()
        self.pipe.disable_attention_slicing()
        for module in (self.pipe.vae, self.pipe.text_encoder, self.pipe.unet):
            remove_hook_from_submodules(module)

    def set_scheduler(self, scheduler):
        """Replace the pipeline's scheduler."""
        self.pipe.scheduler = scheduler

    def name(self):
        """Return this interface's class name."""
        return self.__class__.__name__

    def model_name(self):
        """Return the model id/path this interface was created with."""
        return self._model_name

    def __call__(
        self,
        prompt = "",
        negative_prompt = "",
        image = None,
        height = 64,
        width = 64,
        num_images_per_prompt = 1,
        num_inference_steps = 50,
        eta=0.1,
        strength=0.85,
        guidance_scale = 7.5,
        seed = None
    ):
        """Run the pipeline on ``image`` resized to ``(width, height)`` and
        return its ``.images``.

        A CUDA generator seeded with ``seed`` is used when ``seed`` is not
        ``None``; otherwise no generator is passed.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("image is required")

        g_cuda = None
        if seed is not None:
            g_cuda = torch.Generator(device='cuda')
            g_cuda.manual_seed(seed)

        return self.pipe(
            prompt,
            negative_prompt = negative_prompt,
            image = image.resize((width, height)),
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            eta=eta,
            strength=strength,
            guidance_scale=guidance_scale,
            generator = g_cuda
        ).images

In [None]:
# Schedulers for the diffusion process, keyed by scheduler class name.
# Each one is loaded from the "scheduler" subfolder of `model_id`.
schedulers = {
    scheduler_cls.__name__: scheduler_cls.from_pretrained(model_id, subfolder="scheduler")
    for scheduler_cls in (
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
        DDIMScheduler,
        DPMSolverMultistepScheduler,
        LMSDiscreteScheduler,
        PNDMScheduler,
    )
}


In [None]:
# Start the initial actor: the text-to-image interface for `model_id` with the
# Euler-ancestral scheduler. The generate_* callbacks replace this global when
# a different interface or model is requested.
stable_diffusion_interface = StableDiffusionInterface.remote(
    model_id, scheduler=schedulers["EulerAncestralDiscreteScheduler"], use_auth_token=HF_TOKEN
)

# stable_diffusion_interface = Img2ImgInterface(
#         model_id, 
#         vae, 
#         unet, 
#         text_encoder, 
#         tokenizer,
#         safety_checker = None,
#         feature_extractor = None,
#         scheduler = schedulers['DDIMScheduler'],  
#         use_auth_token = HF_TOKEN,
#     )

# cycle_diffusion_interface = CycleDiffusionInterface(
#     model_id, 
#     vae, 
#     unet, 
#     text_encoder, 
#     tokenizer,
#     safety_checker = None,
#     feature_extractor = None,
#     scheduler = schedulers["PNDMScheduler"],  
#     use_auth_token = HF_TOKEN,
# )
# cycle_diffusion_interface = None



# stable_diffusion_interface = CycleDiffusionInterface(
#             model_id, 
#             vae, 
#             unet, 
#             text_encoder, 
#             tokenizer,
#             safety_checker = None,
#             feature_extractor = None,
#             scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),  
#             use_auth_token = HF_TOKEN,
#         )


# images = stable_diffusion_interface(
#         prompt = "",
#         source_prompt = "",
#         image = Image.open("C:\\Users\\Rustam\\Downloads\\2023-03-26 04-02-23.913974.jpg"),
#         height = 384,
#         width = 384,
#         num_images_per_prompt = 1,
#         num_inference_steps = 30,
#         eta = 0.1,
#         strength = 0.8,
#         guidance_scale =7.5,
#         source_guidance_scale = 1,
#         seed = None,
#     )



In [None]:
import gradio as gr

save_path = ""

def save_images(images, save_path):
    """Save every image into ``save_path`` as a timestamp-named ``.jpg`` file.

    Saving is best effort: a failure is reported on stdout and the remaining
    images are still attempted.

    Args:
        images: Iterable of objects exposing ``save(path)``; ``None`` is
            treated as nothing to save.
        save_path: Target directory. When empty (or ``None``) nothing is saved.
    """
    if not save_path or images is None:
        return

    for image in images:
        try:
            # ":" is replaced because it is not a valid filename character on every platform
            curr_date = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f').replace(":", "-")
            image.save(f"{save_path}/{curr_date}.jpg")
        except Exception as error:
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate
            print(f"Couldn't save image: {error}")

def save_prompts(prompt, neg_prompt, style_name):
    """Store a prompt pair under ``style_name`` in ``prompts.json``.

    An existing style with the same name is overwritten, otherwise a new entry
    is appended. An empty ``style_name`` is replaced by ``default_<timestamp>``.
    Returns a dropdown update listing the styles now on disk, with an empty
    selected value.
    """
    filename = 'prompts.json'

    try:
        with open(filename, "r") as source:
            store = json.load(source)
    except FileNotFoundError:
        # First save: start from an empty collection
        store = {"prompts": []}

    if style_name == "":
        stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f').replace(":", "-")
        style_name = f"default_{stamp}"

    known_names = [entry["style_name"] for entry in store["prompts"]]
    record = {"prompt": prompt, "neg_prompt": neg_prompt, "style_name": style_name}
    try:
        position = known_names.index(style_name)
    except ValueError:
        store["prompts"].append(record)
    else:
        store["prompts"][position] = record

    with open(filename, "w") as target:
        json.dump(store, target, indent=4)

    return gr.Dropdown.update(choices = get_styles(), value = "")

def load_prompts(style_name):
    """Look up ``style_name`` in ``prompts.json``.

    Returns a two-element list of gradio updates: the stored prompt and
    negative prompt when the style exists, or two empty updates when the file
    is missing or the style is unknown.
    """
    filename = 'prompts.json'

    try:
        with open(filename, "r") as source:
            store = json.load(source)
    except FileNotFoundError:
        return [gr.update(), gr.update()]

    known_names = [entry["style_name"] for entry in store["prompts"]]
    if style_name in known_names:
        chosen = store["prompts"][known_names.index(style_name)]
        return [
            gr.update(value = chosen["prompt"]),
            gr.update(value = chosen["neg_prompt"]),
        ]

    return [gr.update(), gr.update()]

def get_styles(filename = 'prompts.json'):
    """Return the ``style_name`` of every entry under ``"prompts"`` in
    ``filename``, in file order; an empty list when the file does not exist."""
    try:
        with open(filename, "r") as source:
            store = json.load(source)
    except FileNotFoundError:
        return []

    names = []
    for entry in store["prompts"]:
        names.append(entry["style_name"])
    return names



def generate_images(
    model_id,
    prompt, 
    negative_prompt = "", 
    num_samples = 1, 
    guidance_scale = 7.5, 
    num_inference_steps = 25, 
    height = 512, 
    width = 512, 
    seed = None, 
    save_path = "",
    scheduler_name = "EulerAncestralDiscreteScheduler",
):
    """Text-to-image callback: generate, optionally save, and return images.

    Replaces the global ``stable_diffusion_interface`` actor when it is not a
    ``StableDiffusionInterface`` or was created for another ``model_id``, then
    installs ``schedulers[scheduler_name]`` and runs the pipeline.

    Args:
        seed: ``None`` or ``-1`` means no fixed seed; any other value is
            converted with ``abs(int(seed))``.
        save_path: Directory handed to ``save_images``; empty disables saving.

    Returns:
        The list of generated images.
    """
    # None (the default) and -1 both mean "no fixed seed"
    seed = None if seed is None or int(seed) == -1 else abs(int(seed))

    global stable_diffusion_interface
    if ray.get(stable_diffusion_interface.name.remote()) != "StableDiffusionInterface" or ray.get(stable_diffusion_interface.model_name.remote()) != model_id:
        ray.kill(stable_diffusion_interface)
        # The interface constructor takes the scheduler as its second argument
        stable_diffusion_interface = StableDiffusionInterface.remote(
            model_id,
            scheduler = schedulers[scheduler_name],
            use_auth_token = HF_TOKEN,
        )
    ray.get(stable_diffusion_interface.set_scheduler.remote(schedulers[scheduler_name]))

    images = ray.get(stable_diffusion_interface.__call__.remote(
        prompt,
        negative_prompt = negative_prompt,
        height = height,
        width = width,
        num_images_per_prompt = num_samples,
        num_inference_steps = num_inference_steps,
        guidance_scale =guidance_scale,
        seed = seed,
    ))

    # Build a new list: rebinding the loop variable would leave arrays unconverted
    images = [
        Image.fromarray(item) if isinstance(item, np.ndarray) else item
        for item in images
    ]

    save_images(images, save_path)
    return images


def cycle_generate_images(
    model_id,
    prompt= "", 
    source_prompt = "",
    num_samples = 1, 
    eta=0.1,
    strength=0.85,
    guidance_scale = 7.5, 
    source_guidance_scale=1,
    num_inference_steps = 25, 
    image = None,
    height = 512, 
    width = 512, 
    seed = None, 
    save_path = "",
    scheduler_name = "DDIMScheduler", 
):
    """Cycle Diffusion callback: edit ``image``, optionally save, and return images.

    Replaces the global ``stable_diffusion_interface`` actor when it is not a
    ``CycleDiffusionInterface`` or was created for another ``model_id``, then
    installs ``schedulers[scheduler_name]`` and runs the pipeline.

    Args:
        image: Dict holding the picture under ``'image'``; a bare PIL image is
            accepted as well.
        seed: ``None`` or ``-1`` means no fixed seed; any other value is
            converted with ``abs(int(seed))``.
        save_path: Directory handed to ``save_images``; empty disables saving.

    Returns:
        The list of generated images, or ``None`` when no usable input image
        was supplied.
    """
    # None (the default) and -1 both mean "no fixed seed"
    seed = None if seed is None or int(seed) == -1 else abs(int(seed))

    # Tolerate a missing input (None) instead of failing on the subscript
    source_image = image.get('image') if isinstance(image, dict) else image
    if not isinstance(source_image, Image.Image):
        print("Cycle Diffusion Pipeline: Couldn't open image")
        return None

    global stable_diffusion_interface
    if ray.get(stable_diffusion_interface.name.remote()) != "CycleDiffusionInterface" or ray.get(stable_diffusion_interface.model_name.remote()) != model_id:
        ray.kill(stable_diffusion_interface)

        # The interface constructor takes the scheduler as its second argument
        stable_diffusion_interface = CycleDiffusionInterface.remote(
            model_id,
            scheduler = schedulers[scheduler_name],
            use_auth_token = HF_TOKEN,
        )
    ray.get(stable_diffusion_interface.set_scheduler.remote(schedulers[scheduler_name]))

    images = ray.get(stable_diffusion_interface.__call__.remote(
        prompt = prompt,
        source_prompt = source_prompt,
        image = source_image,
        height = height,
        width = width,
        num_images_per_prompt = num_samples,
        num_inference_steps = num_inference_steps,
        eta = eta,
        strength = strength,
        guidance_scale =guidance_scale,
        source_guidance_scale = source_guidance_scale,
        seed = seed,
    ))

    # Build a new list: rebinding the loop variable would leave arrays unconverted
    images = [
        Image.fromarray(item) if isinstance(item, np.ndarray) else item
        for item in images
    ]

    save_images(images, save_path)
    return images

def img2img_generate_images(
    model_id,
    prompt = "",
    negative_prompt = "",
    num_samples = 1,
    guidance_scale = 7.5,
    eta = 0.1,
    strength = 0.85,
    num_inference_steps = 25,
    image = None,
    height = 512,
    width = 512,
    seed = None,
    save_path = "",
    scheduler_name = "DDIMScheduler",
):
    """Img2Img callback: transform ``image``, optionally save, and return images.

    Replaces the global ``stable_diffusion_interface`` actor when it is not an
    ``Img2ImgInterface`` or was created for another ``model_id``, then installs
    ``schedulers[scheduler_name]`` and runs the pipeline.

    Args:
        image: Dict holding the picture under ``'image'``; a bare PIL image is
            accepted as well.
        seed: ``None`` or ``-1`` means no fixed seed; any other value is
            converted with ``abs(int(seed))``.
        save_path: Directory handed to ``save_images``; empty disables saving.

    Returns:
        The list of generated images, or ``None`` when no usable input image
        was supplied.
    """
    # None (the default) and -1 both mean "no fixed seed"
    seed = None if seed is None or int(seed) == -1 else abs(int(seed))

    # Tolerate a missing input (None) instead of failing on the subscript
    source_image = image.get('image') if isinstance(image, dict) else image
    if not isinstance(source_image, Image.Image):
        print("Img2Img Pipeline: Couldn't open image")
        return None

    global stable_diffusion_interface
    if ray.get(stable_diffusion_interface.name.remote()) != "Img2ImgInterface" or ray.get(stable_diffusion_interface.model_name.remote()) != model_id:
        ray.kill(stable_diffusion_interface)

        # The interface constructor takes the scheduler as its second argument
        stable_diffusion_interface = Img2ImgInterface.remote(
            model_id,
            scheduler = schedulers[scheduler_name],
            use_auth_token = HF_TOKEN,
        )
    ray.get(stable_diffusion_interface.set_scheduler.remote(schedulers[scheduler_name]))

    images = ray.get(stable_diffusion_interface.__call__.remote(
        prompt = prompt,
        negative_prompt = negative_prompt,
        image = source_image,
        height = height,
        width = width,
        num_images_per_prompt = num_samples,
        num_inference_steps = num_inference_steps,
        eta = eta,
        strength = strength,
        guidance_scale = guidance_scale,
        seed = seed,
    ))

    # Build a new list: rebinding the loop variable would leave arrays unconverted
    images = [
        Image.fromarray(item) if isinstance(item, np.ndarray) else item
        for item in images
    ]

    save_images(images, save_path)
    return images

def inpaint_generate_images(
    model_id,
    prompt = "",
    negative_prompt = "",
    num_samples = 1,
    guidance_scale = 7.5,
    eta = 0.1,
    num_inference_steps = 25,
    image = None,
    height = 512,
    width = 512,
    seed = None,
    save_path = "",
    scheduler_name = "DDIMScheduler",
):
    """Inpainting callback: repaint the masked area, optionally save, and return images.

    Replaces the global ``stable_diffusion_interface`` actor when it is not an
    ``InpaintInterface`` or was created for another ``model_id``, then installs
    ``schedulers[scheduler_name]`` and runs the pipeline.

    Args:
        image: Dict holding the picture under ``'image'`` and the mask under
            ``'mask'``.
        seed: ``None`` or ``-1`` means no fixed seed; any other value is
            converted with ``abs(int(seed))``.
        save_path: Directory handed to ``save_images``; empty disables saving.

    Returns:
        The list of generated images, or ``None`` when the input image or its
        mask is missing.
    """
    # None (the default) and -1 both mean "no fixed seed"
    seed = None if seed is None or int(seed) == -1 else abs(int(seed))

    # Tolerate a missing input (None) or missing keys instead of raising
    payload = image if isinstance(image, dict) else {}
    source_image, mask_image = payload.get('image'), payload.get('mask')
    if not isinstance(source_image, Image.Image) or not isinstance(mask_image, Image.Image):
        print("Inpaint Pipeline: Couldn't open image")
        return None

    global stable_diffusion_interface
    if ray.get(stable_diffusion_interface.name.remote()) != "InpaintInterface" or ray.get(stable_diffusion_interface.model_name.remote()) != model_id:
        ray.kill(stable_diffusion_interface)

        # The interface constructor takes the scheduler as its second argument
        stable_diffusion_interface = InpaintInterface.remote(
            model_id,
            scheduler = schedulers[scheduler_name],
            use_auth_token = HF_TOKEN,
        )
    ray.get(stable_diffusion_interface.set_scheduler.remote(schedulers[scheduler_name]))

    images = ray.get(stable_diffusion_interface.__call__.remote(
        prompt = prompt,
        negative_prompt = negative_prompt,
        image = source_image,
        mask_image = mask_image,
        height = height,
        width = width,
        num_images_per_prompt = num_samples,
        num_inference_steps = num_inference_steps,
        eta = eta,
        guidance_scale = guidance_scale,
        seed = seed,
    ))

    # Build a new list: rebinding the loop variable would leave arrays unconverted
    images = [
        Image.fromarray(item) if isinstance(item, np.ndarray) else item
        for item in images
    ]

    save_images(images, save_path)
    return images

def depth2img_generate_images(
    model_id,   
    prompt = "",
    negative_prompt = "",
    num_samples = 1,
    guidance_scale = 7.5,
    eta = 0.1,
    strength = 0.85,
    num_inference_steps = 25,
    image = None,
    height = 512,
    width = 512,
    seed = None,
    save_path = "",
    scheduler_name = "DDIMScheduler"
):
    """Depth2Img callback: transform ``image``, optionally save, and return images.

    Replaces the global ``stable_diffusion_interface`` actor when it is not a
    ``Depth2ImgInterface`` or was created for another ``model_id``, then
    installs ``schedulers[scheduler_name]`` and runs the pipeline.

    Args:
        image: Dict holding the picture under ``'image'``; a bare PIL image is
            accepted as well. A ``'mask'`` entry is not needed and is ignored.
        seed: ``None`` or ``-1`` means no fixed seed; any other value is
            converted with ``abs(int(seed))``.
        save_path: Directory handed to ``save_images``; empty disables saving.

    Returns:
        The list of generated images, or ``None`` when no usable input image
        was supplied.
    """
    # None (the default) and -1 both mean "no fixed seed"
    seed = None if seed is None or int(seed) == -1 else abs(int(seed))

    # Only the picture is used here; the mask is never read, so a missing
    # 'mask' key can no longer raise.
    source_image = image.get('image') if isinstance(image, dict) else image
    if not isinstance(source_image, Image.Image):
        print("Depth2img Pipeline: Couldn't open image")
        return None

    global stable_diffusion_interface
    if ray.get(stable_diffusion_interface.name.remote()) != "Depth2ImgInterface" or ray.get(stable_diffusion_interface.model_name.remote()) != model_id:
        ray.kill(stable_diffusion_interface)

        # The interface constructor takes the scheduler as its second argument
        stable_diffusion_interface = Depth2ImgInterface.remote(
            model_id,
            scheduler = schedulers[scheduler_name],
            use_auth_token = HF_TOKEN,
        )
    ray.get(stable_diffusion_interface.set_scheduler.remote(schedulers[scheduler_name]))

    images = ray.get(stable_diffusion_interface.__call__.remote(
        prompt = prompt,
        negative_prompt = negative_prompt,
        image = source_image,
        height = height,
        width = width,
        num_images_per_prompt = num_samples,
        num_inference_steps = num_inference_steps,
        eta = eta,
        strength = strength,
        guidance_scale = guidance_scale,
        seed = seed,
    ))

    # Build a new list: rebinding the loop variable would leave arrays unconverted
    images = [
        Image.fromarray(item) if isinstance(item, np.ndarray) else item
        for item in images
    ]

    save_images(images, save_path)
    return images
    

def select_interface(interface_name):
    """Produce the component updates that reconfigure the UI for one pipeline.

    Args:
        interface_name: Label of the pipeline picked in the "Interface" radio.

    Returns:
        A dict mapping each affected Gradio component to a ``gr.update(...)``
        (visibility for most controls, choices and value for the scheduler
        dropdown), or ``None`` when the label matches no known pipeline.
    """
    # One row per pipeline, columns in this order:
    #   negative prompt, source prompt, DDIM-only scheduler, eta, strength,
    #   source guidance scale, image input, key of the Generate button to show.
    layouts = (
        ("Stable Diffusion pipeline", (True, False, False, False, False, False, False, "generate")),
        ("Cycle Diffusion pipeline", (False, True, True, True, True, True, True, "cycle")),
        ("Img2Img Pipeline", (True, False, False, True, True, False, True, "img2img")),
        ("Inpaint Pipeline", (True, False, False, True, False, False, True, "inpaint")),
        ("Depth2Img Pipeline", (True, False, False, True, True, False, True, "depth2img")),
    )
    layout = next((row for label, row in layouts if interface_name == label), None)
    if layout is None:
        return None

    (
        show_negative,
        show_source,
        ddim_only,
        show_eta,
        show_strength,
        show_source_scale,
        show_image,
        active_button,
    ) = layout

    if ddim_only:
        scheduler_choices = ["DDIMScheduler"]
    else:
        scheduler_choices = [
            "EulerAncestralDiscreteScheduler",
            "EulerDiscreteScheduler",
            "DDIMScheduler",
            "DPMSolverMultistepScheduler",
            "LMSDiscreteScheduler",
            "PNDMScheduler",
        ]

    updates = {
        negative_prompt: gr.update(visible=show_negative),
        source_prompt: gr.update(visible=show_source),
        # The first entry of each choice list doubles as the preselected value.
        scheduler: gr.update(choices=scheduler_choices, value=scheduler_choices[0]),
        eta: gr.update(visible=show_eta),
        strength: gr.update(visible=show_strength),
        source_guidance_scale: gr.update(visible=show_source_scale),
        image_input: gr.update(visible=show_image),
    }

    # Exactly one Generate button is shown: the one wired to the chosen pipeline.
    button_table = (
        ("generate", generate_button),
        ("cycle", cycle_generate_button),
        ("img2img", img2img_generate_button),
        ("inpaint", inpaint_generate_button),
        ("depth2img", depth2img_generate_button),
    )
    for key, button in button_table:
        updates[button] = gr.update(visible=(key == active_button))

    return updates


# Build the Gradio UI: a "Main" tab with the prompt and sampling controls and a
# second tab with model identifiers and the image save location, then wire each
# button to its generation function and launch the app.
with gr.Blocks() as demo:
    gr.Markdown("Hugging Face Stable Diffusion")

    with gr.Tab("Main"):
        # Pipeline selector; its change event drives select_interface (wired below).
        interfaces_box = gr.Radio(
            label="Interface",
            choices = [
                "Stable Diffusion pipeline",
                "Cycle Diffusion pipeline",
                "Img2Img Pipeline",
                "Inpaint Pipeline",
                "Depth2Img Pipeline",
            ],
            value = "Stable Diffusion pipeline")

        with gr.Row():
            with gr.Column(scale = 0.7):
                prompt = gr.Textbox(label="Prompt", lines = 3)

                # Only one of these two is visible at a time, depending on the pipeline.
                source_prompt = gr.Textbox(label="Source Prompt", visible=False, lines = 3)
                negative_prompt = gr.Textbox(label="Negative Prompt", visible=True, lines = 3)

            with gr.Column(scale = 0.3):
                # Saved prompt "styles": pick one to load, or store the current prompts.
                prompts_list = gr.Dropdown(
                    choices = get_styles(), label="Styles"
                    )
                load_prompts_button = gr.Button("Load Style")
                save_prompts_name = gr.Textbox(label="Save Style Name")
                save_prompts_button = gr.Button("Save Style")


        num_samples = gr.Slider(label="Number of samples", value = 1, step = 1, minimum = 1, maximum = 4)
        height = gr.Slider(label="Height", value = 512, step = 64, minimum = 64, maximum = 1024)
        width = gr.Slider(label="Width", value = 512, step = 64, minimum = 64, maximum = 1024)

        # Hidden for the default pipeline; select_interface reveals them when relevant.
        eta = gr.Slider(label="Eta", value = 0.1, step = 0.1, minimum = 0.1, maximum = 1.0, visible=False)
        strength = gr.Slider(label="Strength", value = 0.85, step = 0.05, minimum = 0.05, maximum = 1.0, visible=False)

        with gr.Row():

            with gr.Column():
                scheduler = gr.Dropdown(
                    label="Scheduler",
                    choices=[
                        "EulerAncestralDiscreteScheduler",
                        "EulerDiscreteScheduler",
                        "DDIMScheduler",
                        "DPMSolverMultistepScheduler",
                        "LMSDiscreteScheduler",
                        "PNDMScheduler"
                    ],
                    value = "EulerAncestralDiscreteScheduler"
                )
                num_inference_steps = gr.Slider(label="Number of inference steps", value = 25, step = 1, minimum = 1, maximum = 100)
                guidance_scale = gr.Slider(label="Guidance scale", value = 7.5, step = 0.1, minimum = 0.1, maximum = 10.0)
                source_guidance_scale = gr.Slider(label="Source guidance scale", value = 1, step = 0.1, minimum = 0.1, maximum = 10.0, visible=False)

                # -1 is the "no fixed seed" sentinel understood by the generate functions.
                seed = gr.Number(label="Seed", value = -1)

                # One Generate button per pipeline; only the active one is visible.
                generate_button = gr.Button("Generate")
                cycle_generate_button = gr.Button("Generate", visible=False)
                img2img_generate_button = gr.Button("Generate", visible=False)
                inpaint_generate_button = gr.Button("Generate", visible=False)
                depth2img_generate_button = gr.Button("Generate", visible=False)

            image_input = gr.Image(type="pil", tool='sketch', visible=False)
            image_output = gr.Gallery(show_label=False).style(grid=[2], height="auto", preview = True)

    with gr.Tab("Pathes"):
        with gr.Column():

            model_identifier = gr.Textbox(label="Model id or path", value = model_id)
            inpaint_model_identifier = gr.Textbox(label="Inpaint Model id or path", value = inpaint_model_id)
            depth2img_model_identifier = gr.Textbox(label="Depth2Img Model id or path", value = depth2img_model_id)

            # Bound to its own name so the module-level `save_path` string is not
            # replaced by a component; re-running this cell keeps seeding the
            # textbox with the original path.
            save_path_box = gr.Textbox(label="Images save path", value = save_path)


    # Reconfigure the controls whenever another pipeline is selected. The output
    # list must contain every component select_interface returns an update for.
    interfaces_box.change(
        select_interface,
        interfaces_box,
        [
            negative_prompt,
            source_prompt,
            scheduler,
            eta,
            strength,
            source_guidance_scale,
            image_input,
            generate_button,
            cycle_generate_button,
            img2img_generate_button,
            inpaint_generate_button,
            depth2img_generate_button
        ]
    )

    # Inputs are passed positionally, so each list's order has to match the
    # parameter order of the function it is wired to.
    generate_button.click(
        generate_images,
        inputs = [
            model_identifier,
            prompt,
            negative_prompt,
            num_samples,
            guidance_scale,
            num_inference_steps,
            height,
            width,
            seed,
            save_path_box,
            scheduler
        ],
        outputs=image_output
    )

    cycle_generate_button.click(
        cycle_generate_images,
        inputs = [
            model_identifier,
            prompt,
            source_prompt,
            num_samples,
            eta,
            strength,
            guidance_scale,
            source_guidance_scale,
            num_inference_steps,
            image_input,
            height,
            width,
            seed,
            save_path_box,
            scheduler

        ],
        outputs=image_output
    )

    img2img_generate_button.click(
        img2img_generate_images,
        inputs = [
            model_identifier,
            prompt,
            negative_prompt,
            num_samples,
            guidance_scale,
            eta,
            strength,
            num_inference_steps,
            image_input,
            height,
            width,
            seed,
            save_path_box,
            scheduler
        ],
        outputs=image_output
    )

    inpaint_generate_button.click(
        inpaint_generate_images,
        inputs = [
            inpaint_model_identifier,
            prompt,
            negative_prompt,
            num_samples,
            guidance_scale,
            eta,
            num_inference_steps,
            image_input,
            height,
            width,
            seed,
            save_path_box,
            scheduler
        ],
        outputs=image_output
    )

    depth2img_generate_button.click(
        depth2img_generate_images,
        inputs = [
            depth2img_model_identifier,
            prompt,
            negative_prompt,
            num_samples,
            guidance_scale,
            eta,
            strength,
            num_inference_steps,
            image_input,
            height,
            width,
            seed,
            save_path_box,
            scheduler
        ],
        outputs=image_output
    )

    # Saving a style writes its result back into the styles dropdown.
    save_prompts_button.click(
        save_prompts,
        inputs = [
            prompt,
            negative_prompt,
            save_prompts_name,
        ],
        outputs = prompts_list
    )

    # Loading a style fills both prompt boxes.
    load_prompts_button.click(
        load_prompts,
        inputs = [
            prompts_list,
        ],
        outputs = [
            prompt,
            negative_prompt
        ]
    )

# Enable request queuing, then start the server with share = True.
demo.queue()
demo.launch(share = True)