In [None]:
import gc
import os

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from PIL import Image
#todo
"""
Этап реализации класса

1) сделать возможность выбора пути к изображению и/или объекта изображения для улучшения - Done, Tested - Done
2) перехват ошибок, очистка памяти, вывод кодов ошибок - Done, Tested - Done
3) сохранение изображений именно как объектов - Done, Tested - Done

Этап реализации приложения

4) выбор других моделей - None, Tested - None
5) дообучение моделей - None, Tested - None
6) графическая оболочка - None, Tested - None
"""

class Visualizer:
    """Generate images with an SDXL base pipeline and refine them with an SDXL img2img pipeline.

    Each generation method loads its pipeline, runs it once and unloads it
    again, so only one pipeline is held at a time. Failures during the
    pipeline call are reported by returning a string ``"error: <message>"``
    instead of raising (model-loading failures still raise).
    """

    def __init__(self, base_model_id, refiner_model_id):
        """Store model identifiers; no weights are loaded here.

        Args:
            base_model_id: identifier passed to
                ``StableDiffusionXLPipeline.from_pretrained``.
            refiner_model_id: identifier passed to
                ``StableDiffusionXLImg2ImgPipeline.from_pretrained``.
        """
        self.base_model_id = base_model_id
        self.refiner_model_id = refiner_model_id
        self.pipe = None                # currently loaded pipeline, or None
        self.image_path = None          # path_to_save of the last successful create_image
        self.image_refined_path = None  # path_to_save of the last successful refine_image
        self.image = None               # last image returned by create_image
        self.refined_image = None       # last image returned by refine_image

    def load_base_model(self):
        """Load the SDXL text-to-image pipeline (float16 weights) into ``self.pipe``."""
        self.pipe = StableDiffusionXLPipeline.from_pretrained(self.base_model_id, torch_dtype=torch.float16)
        # NOTE(review): SDXL pipelines may not consult this attribute; kept as in the original.
        self.pipe.safety_checker = None
        # NOTE(review): moving to CUDA is disabled here (unlike the refiner), so the
        # base pipeline stays on its default device — confirm this is intended.
        #self.pipe = self.pipe.to("cuda")

    def unload_model(self):
        """Drop the current pipeline and release cached GPU memory if CUDA is present.

        Safe to call repeatedly and on CPU-only machines.
        """
        # Rebind instead of `del` so `self.pipe` always exists afterwards.
        self.pipe = None
        # Collect first so the pipeline's tensors are really freed before the cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

    def create_image(
        self,
        prompt,
        negative_prompt,
        height=1536,
        width=1024,
        guidance_scale=15,
        save_image=False,
        num_inference_steps=70,
        path_to_save="output_base.png",
        num_images_per_prompt=1,
        generator=None,
    ):
        """Generate an image from text with the base pipeline.

        Args:
            prompt: text prompt.
            negative_prompt: negative text prompt.
            height, width: output size in pixels, forwarded to the pipeline.
            guidance_scale: classifier-free guidance scale.
            save_image: if true, save the result to ``path_to_save``.
            num_inference_steps: number of denoising steps.
            path_to_save: file path used when ``save_image`` is true; also
                stored in ``self.image_path``.
            num_images_per_prompt: forwarded to the pipeline; only the first
                image is returned.
            generator: optional generator forwarded to the pipeline for
                reproducible sampling.

        Returns:
            The first generated image (also stored in ``self.image``), or a
            string ``"error: ..."`` if the pipeline call or save failed.
        """
        self.load_base_model()
        try:
            call_kwargs = dict(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=num_images_per_prompt,
            )
            if generator is not None:
                call_kwargs["generator"] = generator
            image = self.pipe(**call_kwargs).images[0]
            self.image_path = path_to_save
            if save_image:
                image.save(path_to_save)
            self.image = image
            return image
        except Exception as e:
            return f"error: {e}"
        finally:
            # Always free the pipeline, on success and on failure.
            self.unload_model()

    def load_refiner_model(self):
        """Load the SDXL img2img pipeline (float16 weights) into ``self.pipe`` and move it to CUDA."""
        self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(self.refiner_model_id, torch_dtype=torch.float16)
        self.pipe = self.pipe.to("cuda")

    def refine_image(
        self,
        prompt,
        strength=0.5,
        guidance_scale=15,
        save_image=False,
        path_to_save="output_refined.png",
        num_inference_steps=25,
        image=None,
        image_path=None,
        num_images_per_prompt=1,
        generator=None,
    ):
        """Refine an image with the img2img pipeline.

        The input image is chosen in this order: ``image`` if given, else the
        file at ``image_path``, else ``self.image`` from the last
        ``create_image`` call.

        Args:
            prompt: text prompt for the refinement pass.
            strength: img2img strength forwarded to the pipeline.
            guidance_scale: classifier-free guidance scale.
            save_image: if true, save the result to ``path_to_save``.
            path_to_save: file path used when ``save_image`` is true; also
                stored in ``self.image_refined_path``.
            num_inference_steps: number of denoising steps, forwarded to the pipeline.
            image: optional input image object.
            image_path: optional path to an input image file (converted to RGB).
            num_images_per_prompt: forwarded to the pipeline; only the first
                image is returned.
            generator: optional generator forwarded to the pipeline.

        Returns:
            The first refined image (also stored in ``self.refined_image``),
            or a string ``"error: ..."`` if no input is available, the input
            file cannot be read, or the pipeline call or save failed.
        """
        # Resolve and validate the input before loading the refiner, so bad
        # input does not cost a model load.
        if image is None and image_path:
            try:
                with Image.open(image_path) as source:
                    # convert() forces the pixel data to load before the file closes, and
                    # normalizes RGBA/palette/grayscale files to 3 channels.
                    image = source.convert("RGB")
            except Exception as e:
                return f"error: {e}"
        if image is None:
            image = self.image
        if image is None:
            return "error: no input image (pass image or image_path, or call create_image first)"

        self.load_refiner_model()
        try:
            call_kwargs = dict(
                prompt=prompt,
                image=image,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=num_images_per_prompt,
            )
            if generator is not None:
                call_kwargs["generator"] = generator
            refined_image = self.pipe(**call_kwargs).images[0]
            if save_image:
                refined_image.save(path_to_save)
            self.image_refined_path = path_to_save
            self.refined_image = refined_image
            return refined_image
        except Exception as e:
            return f"error: {e}"
        finally:
            # Always free the pipeline, on success and on failure.
            self.unload_model()
