In [1]:
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as nnf

from diffusers import DDIMScheduler, KandinskyPipeline

path = "kandinsky-community/kandinsky-2-2"

from PIL import Image
from torchvision import transforms

import inspect
from typing import Callable, List, Optional, Union

import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers.models import UNet2DConditionModel, VQModel
from diffusers.schedulers import DDIMScheduler, DDPMScheduler
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
# from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

from diffusers import KandinskyPriorPipeline, KandinskyPipeline, KandinskyV22PriorPipeline, KandinskyV22Pipeline
from diffusers.utils import load_image
import PIL
from torchvision import transforms
from diffusers.schedulers import DDIMScheduler,PNDMScheduler, LMSDiscreteScheduler
from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import randn_tensor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

  from .autonotebook import tqdm as notebook_tqdm
  from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput


In [2]:
def backward_ddim(x_t, alpha_t: "alpha_t", alpha_tm1: "alpha_{t-1}", eps_xt):
    """Deterministic DDIM update taking ``x_t`` from ``alpha_t`` to ``alpha_tm1``.

    The step is computed as a residual that is then added back onto ``x_t``.
    """
    # Weight applied to the current sample inside the residual.
    sample_weight = alpha_t**-0.5 - alpha_tm1**-0.5
    # Weight applied to the predicted noise inside the residual.
    noise_weight = (1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5
    residual = alpha_tm1**0.5 * (sample_weight * x_t + noise_weight * eps_xt)
    return residual + x_t

def forward_ddim(x_t, alpha_t: "alpha_t", alpha_tp1: "alpha_{t+1}", eps_xt):
    """Image-to-noise DDIM step: the same update as ``backward_ddim`` with ``alpha_tp1`` as the target."""
    return backward_ddim(x_t=x_t, alpha_t=alpha_t, alpha_tm1=alpha_tp1, eps_xt=eps_xt)

In [3]:

def get_new_h_w(h, w, scale_factor=8):
    """Map a pixel height/width to the matching latent height/width.

    Each side is divided by ``scale_factor**2`` rounding up, then multiplied
    by ``scale_factor``.
    """
    block = scale_factor**2

    def _ceil_blocks(size):
        whole, leftover = divmod(size, block)
        return whole + 1 if leftover != 0 else whole

    return _ceil_blocks(h) * scale_factor, _ceil_blocks(w) * scale_factor


In [4]:
def prepare_image(pil_image, w=512, h=512):
    """Resize a PIL image to ``(w, h)`` and return it as a ``(1, 3, h, w)`` float16 tensor scaled to [-1, 1]."""
    resized = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    # 0..255 pixel values become -1..1 (0 -> -1, 255 -> 1), kept in half precision.
    pixels = np.array(resized.convert("RGB")).astype(np.float16) / 127.5 - 1
    # HWC -> CHW, then add the leading batch axis.
    channels_first = np.transpose(pixels, [2, 0, 1])
    return torch.from_numpy(channels_first).unsqueeze(0)


class NewKandinskyPipeline(KandinskyV22Pipeline):
    """Kandinsky 2.2 decoder pipeline with a hand-written loop that can run in either direction.

    ``backward_diffusion`` walks the selected scheduler timesteps in their
    stored order and advances latents with ``scheduler.step``.
    ``forward_diffusion`` is the same method bound with ``reverse_process=True``:
    it walks the timesteps in reversed order and advances latents with the
    closed-form ``backward_ddim`` update. Both return latents, not images; use
    ``decode_image`` / ``torch_to_numpy`` to obtain pixels.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        movq: VQModel,
    ):
        super().__init__(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        # Inversion reuses the same loop with the step direction flipped.
        self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)

    @torch.inference_mode()
    def get_image_latents(self, image, sample=False, rng_generator=None):
        """Encode ``image`` with the MoVQ encoder and return its latents unscaled.

        ``sample`` and ``rng_generator`` are accepted for interface
        compatibility but are not used by this implementation.
        """
        return self.movq.encode(image).latents

    def get_timesteps(self, num_inference_steps, strength, device):
        """Select the tail of the scheduler's timesteps according to ``strength``.

        Keeps the last ``min(int(num_inference_steps * strength), num_inference_steps)``
        entries of ``self.scheduler.timesteps``.

        Returns:
            ``(timesteps, n)`` where ``n`` is the number of steps kept.
            ``device`` is accepted but unused.
        """
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]
        return timesteps, num_inference_steps - t_start

    @torch.no_grad()
    def backward_diffusion(
        self,
        prompt: Union[str, List[str]],
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
        reverse_process: bool = False,
        strength: float = 0.3,
    ):
        """Run the denoising loop (or its inverse) over latents and return the resulting latents.

        Args:
            prompt: Only used to derive the batch size (``str`` -> 1, ``list`` -> its length).
            image_embeds: Conditioning embeddings; a list is concatenated along dim 0.
            negative_image_embeds: Unconditional embeddings; only used when
                ``guidance_scale > 1``.
            negative_prompt, output_type, return_dict: Accepted but unused here.
            height, width: Pixel size, converted to a latent size via ``get_new_h_w``.
            num_inference_steps: Number of scheduler steps before ``strength`` truncation.
            guidance_scale: Classifier-free guidance is enabled when this is ``> 1``.
            num_images_per_prompt: Multiplies the batch size.
            generator: Forwarded to ``prepare_latents`` and ``scheduler.step``.
            latents: Optional starting latents, forwarded to ``prepare_latents``.
            callback, callback_steps: ``callback(step_idx, t, latents)`` is invoked
                every ``callback_steps`` iterations.
            reverse_process: When true, iterate timesteps in reversed order and use
                the closed-form DDIM update instead of ``scheduler.step``.
            strength: Fraction of the schedule to run (see ``get_timesteps``).

        Returns:
            The latents after the final step.

        Raises:
            ValueError: If ``prompt`` is neither ``str`` nor ``list``.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        device = self._execution_device
        # Follow the UNet's dtype instead of hard-coding float16.
        dtype = self.unet.dtype

        batch_size = batch_size * num_images_per_prompt
        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        if do_classifier_free_guidance:
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Unconditional half first, conditional half second (matches the chunk(2) below).
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

        # Done for every call: previously the cast lived inside the CFG branch, so
        # guidance_scale <= 1 left the embeddings on whatever device/dtype the caller used.
        image_embeds = image_embeds.to(dtype=dtype, device=device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        num_channels_latents = self.unet.config.in_channels

        height, width = get_new_h_w(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        step_order = reversed(timesteps_tensor) if reverse_process else timesteps_tensor
        for i, t in enumerate(self.progress_bar(step_order)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            # Drop the second channel group unless the scheduler is configured for a learned variance.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            if reverse_process:
                # Only the inversion path needs the cumulative alphas, so they are
                # looked up here rather than on every iteration.
                prev_timestep = (
                    t
                    - self.scheduler.config.num_train_timesteps
                    // self.scheduler.num_inference_steps
                )
                alpha_at_t = self.scheduler.alphas_cumprod[t]
                alpha_at_prev = (
                    self.scheduler.alphas_cumprod[prev_timestep]
                    if prev_timestep >= 0
                    else self.scheduler.final_alpha_cumprod
                )
                # Inversion moves from the previous timestep's alpha towards alpha at t.
                latents = backward_ddim(
                    x_t=latents,
                    alpha_t=alpha_at_prev,
                    alpha_tm1=alpha_at_t,
                    eps_xt=noise_pred,
                )
            else:
                latents = self.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=generator,
                ).prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        return latents

    @torch.inference_mode()
    def decode_image(self, latents: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Decode latents with MoVQ (``force_not_quantize=True``) and return the ``"sample"`` entry.

        Extra keyword arguments are accepted and ignored.
        """
        return self.movq.decode(latents, force_not_quantize=True)["sample"]

    @torch.inference_mode()
    def torch_to_numpy(self, image) -> np.ndarray:
        """Rescale ``image`` from [-1, 1] to [0, 1] (clamped) and return it as a channels-last NumPy array on CPU."""
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return image

In [5]:
from diffusers.models import PriorTransformer
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.schedulers import UnCLIPScheduler
from diffusers.utils import BaseOutput

# `BaseOutput` subclasses must be dataclasses: without the decorator the fields
# below are bare annotations and attribute access (e.g. `.image_embeds`) fails.
@dataclass
class KandinskyPriorPipelineOutput(BaseOutput):
    """
    Output class for KandinskyPriorPipeline.

    Args:
        image_embeds (`torch.FloatTensor` or `np.ndarray`)
            clip image embeddings for text prompt
        negative_image_embeds (`torch.FloatTensor` or `np.ndarray`)
            clip image embeddings for unconditional tokens
    """

    image_embeds: Union[torch.FloatTensor, np.ndarray]
    negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
    
class NewKandinskyPriorPipeline(KandinskyV22PriorPipeline):
    """Kandinsky 2.2 prior pipeline with an image-initialised variant of the prior loop.

    ``new_forward`` starts the prior from the (noised) CLIP embedding of an
    input image instead of from pure noise, and runs only the part of the
    schedule selected by ``strength``.
    """

    def __init__(
        self,
        prior: PriorTransformer,
        image_encoder: CLIPVisionModelWithProjection,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: UnCLIPScheduler,
        image_processor: CLIPImageProcessor,
    ):
        super().__init__(
            prior=prior,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )

    def _encode_image(
        self,
        image: Union[torch.Tensor, List[PIL.Image.Image]],
        device,
        num_images_per_prompt,
    ):
        """Return image-encoder embeddings for ``image``, repeated ``num_images_per_prompt`` times.

        Non-tensor input is first run through ``self.image_processor`` and moved
        to ``device`` in the image encoder's dtype; tensor input is passed to the
        encoder unchanged.
        """
        if not isinstance(image, torch.Tensor):
            image = self.image_processor(image, return_tensors="pt").pixel_values.to(
                dtype=self.image_encoder.dtype, device=device
            )

        image_emb = self.image_encoder(image)["image_embeds"]  # B, D
        image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0)
        # `Tensor.to` returns a new tensor; the result was previously discarded.
        image_emb = image_emb.to(device=device)

        return image_emb

    def prepare_latents_new(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """Tile ``emb`` up to the effective batch size and noise it to ``timestep``.

        Raises:
            ValueError: If the effective batch size (``batch_size * num_images_per_prompt``)
                is larger than, and not a multiple of, the number of rows in ``emb``.
        """
        emb = emb.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        init_latents = emb

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    def get_timesteps(self, num_inference_steps, strength, device):
        """Select the tail of the scheduler's timesteps according to ``strength``.

        Returns ``(timesteps, n)`` where ``n`` is the number of steps kept;
        ``device`` is accepted but unused.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    @torch.no_grad()
    def new_forward(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        guidance_scale: float = 4.0,
        output_type: Optional[str] = "pt",
        return_dict: bool = True,
        image_pil = None,
        strength: float = 0.3,
    ):
        """Run the prior starting from the embedding of ``image_pil`` rather than from noise.

        Args:
            prompt: Text prompt(s) conditioning the prior.
            negative_prompt: Optional negative prompt(s); when given, they are appended
                to ``prompt`` and the result is split in two at the end.
            num_images_per_prompt: Multiplies the batch size.
            num_inference_steps: Scheduler steps before ``strength`` truncation.
            generator: Forwarded to noise sampling and ``scheduler.step``.
            latents: Accepted but ignored; the starting latents always come from ``image_pil``.
            guidance_scale: Classifier-free guidance is enabled when this is ``> 1``.
            output_type: ``"pt"`` or ``"np"``.
            return_dict: When false, a plain ``(image_embeds, negative_image_embeds)`` tuple is returned.
            image_pil: Image (or tensor) whose embedding seeds the prior. Required.
            strength: Fraction of the schedule to run (see ``get_timesteps``).

        Returns:
            ``KandinskyPriorPipelineOutput`` or a tuple, depending on ``return_dict``.

        Raises:
            ValueError: On a wrongly typed ``prompt`` / ``negative_prompt``, a missing
                ``image_pil``, or an unsupported ``output_type``.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        elif not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        elif not isinstance(negative_prompt, list) and negative_prompt is not None:
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        # Fail early with a clear message instead of deep inside the image processor.
        if image_pil is None:
            raise ValueError("`image_pil` has to be provided: `new_forward` starts from an image embedding.")

        # if the negative prompt is defined we double the batch size to
        # directly retrieve the negative prompt embedding
        if negative_prompt is not None:
            prompt = prompt + negative_prompt
            negative_prompt = 2 * negative_prompt

        device = self._execution_device

        batch_size = len(prompt)
        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        image_embeds = self._encode_image(image_pil, device, num_images_per_prompt)

        # prior
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        prior_timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        # Noise the image embedding to the first timestep that will actually be run.
        latent_timestep = prior_timesteps_tensor[:1].repeat(batch_size)
        latents = self.prepare_latents_new(
            image_embeds,
            latent_timestep,
            batch_size // num_images_per_prompt,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            latents = self.scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

        latents = self.prior.post_process_latents(latents)

        image_embeddings = latents

        # if a negative prompt has been defined, split the image embedding into two
        if negative_prompt is None:
            zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)

        else:
            image_embeddings, zero_embeds = image_embeddings.chunk(2)

            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.prior_hook.offload()

        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

        if output_type == "np":
            image_embeddings = image_embeddings.cpu().numpy()
            zero_embeds = zero_embeds.cpu().numpy()

        if not return_dict:
            return (image_embeddings, zero_embeds)

        return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)

In [6]:
# Load the Kandinsky 2.2 decoder in half precision through the custom pipeline class.
pipe = NewKandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
# Swap in a DDIM scheduler built from the checkpoint's scheduler config.
# `from_pretrained` replaces the deprecated "config passed as a path" form of `from_config`.
pipe.scheduler = DDIMScheduler.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", subfolder="scheduler")
pipe = pipe.to("mps")

  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)


In [7]:
# Load the Kandinsky 2.2 prior in half precision through the custom prior pipeline class.
pipe_prior = NewKandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)

# Move the prior to the "mps" device; as the cell's last expression, the returned
# pipeline's component summary is displayed.
pipe_prior.to("mps")

NewKandinskyPriorPipeline {
  "_class_name": "NewKandinskyPriorPipeline",
  "_diffusers_version": "0.18.2",
  "image_encoder": [
    "transformers",
    "CLIPVisionModelWithProjection"
  ],
  "image_processor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "prior": [
    "diffusers",
    "PriorTransformer"
  ],
  "scheduler": [
    "diffusers",
    "UnCLIPScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ]
}

In [8]:
def load_img(path, target_size=512):
    """Load an image, resize it to ``target_size`` x ``target_size`` and output -1..1.

    Args:
        path: Location of the image file, as accepted by ``Image.open``.
        target_size: Side length in pixels of the square output.

    Returns:
        Whatever ``prepare_image`` returns for the RGB-converted image.
    """
    # The context manager closes the underlying file; `convert` yields an independent copy.
    with Image.open(path) as source:
        image = source.convert("RGB")

    # `target_size` was previously accepted but never forwarded.
    return prepare_image(image, w=target_size, h=target_size)

In [None]:
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, BertForSequenceClassification
import torch.nn as nn
from transformers import AdamW
import os


class T2IModel(nn.Module):
    """Maps an (image, instruction) pair to a 1280-dimensional embedding.

    Image and text embeddings are concatenated and projected by a single
    linear layer (2560 -> 1280); only that layer is handed to the optimizer.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `forward` reads `.text_embeds` from this model's output, which is the
        # field exposed by the CLIPTextModelWithProjection variant this replaced
        # ("kandinsky-community/kandinsky-2-2-prior", subfolder="text_encoder").
        # Confirm a BERT sequence classifier is really intended here.
        self.text_model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
        )
        self.vision_model = CLIPVisionModelWithProjection.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior", subfolder="image_encoder"
        )
        self.fc = nn.Linear(2560, 1280)

    def initialize_optimizer(self):
        """Return an AdamW optimizer (lr=1e-4) over the projection layer's parameters only."""
        params = list(self.fc.parameters())
        optimizer = AdamW(params, lr=1e-4)
        return optimizer

    def forward(self, input_imgs, instructions):
        """Embed images and instructions, concatenate along dim 1, and project with ``self.fc``."""
        text_embeds = self.text_model(instructions).text_embeds
        vision_embeds = self.vision_model(input_imgs).image_embeds
        x = torch.cat((vision_embeds, text_embeds), 1)
        x = self.fc(x)
        return x

    def output_embedding(self, target_images):
        """Return the vision model's ``image_embeds`` for ``target_images``."""
        target_image_output = self.vision_model(target_images)
        target_image_embeds = target_image_output.image_embeds
        return target_image_embeds

    def custom_loss(self, output_embeddings, target_embeddings):
        """Mean-squared error between predicted and target embeddings."""
        mse_loss = nn.MSELoss()
        loss = mse_loss(output_embeddings, target_embeddings)

        return loss

    def save_model(self, output_dir="../model_save/", filename="model_checkpoint.pt"):
        """Write this module's ``state_dict`` to ``output_dir/filename``, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        # os.path.join tolerates an `output_dir` without a trailing separator.
        file_path = os.path.join(output_dir, filename)
        print("Saving model to %s" % file_path)

        # Serialise this instance; the previous code reached for a global named `model`.
        torch.save(self.state_dict(), file_path)

    def get_cos(self, input1, input2):
        """Average cosine similarity (along dim 1) between the two inputs."""
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        similarity = cos(input1, input2)
        avg = torch.sum(similarity) / len(similarity)
        return avg

    def metrics(self, input1, input2):
        """Return the evaluation metrics as a list; currently just the average cosine similarity."""
        cos = self.get_cos(input1, input2)
        return [cos]

    def visualization(self, input_img, instruction, filename,negative_instruction = ""):
        """Decode this model's embedding for ``(input_img, instruction)`` into an image saved at ``filename``.

        The negative embedding is produced from the same image with
        ``negative_instruction``. A Kandinsky 2.2 decoder pipeline is loaded on
        every call.
        """
        output_embeddings = self.forward(input_img, instruction)
        neg_image_embed = self.forward(input_img, negative_instruction)

        # Use the device this module lives on instead of an undefined global `device`.
        device = next(self.parameters()).device

        pipe = KandinskyV22Pipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        )
        pipe.to(device)
        image = pipe(
            image_embeds = output_embeddings,
            negative_image_embeds = neg_image_embed,
            height = 768,
            width = 768,
            num_inference_steps=100,
        ).images

        image[0].save(filename)


# `device` is not defined earlier in the notebook, so choose one explicitly:
# CUDA if present, otherwise Apple's MPS backend, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = T2IModel()
model.to(device=device)

import torch

def load_model_from_checkpoint(model, checkpoint_path, device='cuda'):
    """
    Restore a model's weights from a state-dict checkpoint.

    Parameters:
    - model (torch.nn.Module): Instantiated architecture that receives the weights.
    - checkpoint_path (str): File holding a state dict saved with ``torch.save``.
    - device (str): Target device; used both as ``map_location`` and as the
      device the model is moved to.

    Returns:
    - model (torch.nn.Module): The same ``model`` object, with weights loaded and moved to ``device``.
    """
    weights = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(weights)
    model.to(device)
    return model

# Usage: restore the trained weights into `model` and keep a handle to it as `loaded_model`.
# NOTE(review): this loads onto 'cuda', while the diffusion pipelines in this notebook are
# moved to "mps" — confirm which device is intended.
loaded_model = load_model_from_checkpoint(model, 'magicbrush_kadinsky_imagewithinstruction_10epochs_full_v1.pth', device='cuda')

In [9]:
def get_edited_image(img_path, alternate_prompt, prompt=""):
    """Edit an image by inverting it to noisy latents and denoising under a new embedding.

    Steps:
      1. Load the image and encode it to decoder latents.
      2. Ask ``loaded_model`` for three embeddings: image + ``prompt`` (main),
         image + ``alternate_prompt`` (alternate), and a black image (negative).
      3. Run ``pipe.forward_diffusion`` under the main embedding.
      4. Run ``pipe.backward_diffusion`` under the alternate embedding.
      5. Decode the resulting latents to a PIL image.

    Args:
        img_path: Path of the image to edit.
        alternate_prompt: Instruction describing the desired edit.
        prompt: Instruction paired with the original image (defaults to empty).

    Returns:
        The first PIL image produced by ``pipe.numpy_to_pil``.
    """
    img = load_img(img_path, 512).to("mps", dtype=torch.float16)
    image_latents = pipe.get_image_latents(img)

    # An earlier revision derived these embeddings from `pipe_prior.new_forward(...)`;
    # they now come from the fine-tuned `loaded_model`.
    model_param = next(loaded_model.parameters())

    def _embed(image, instruction):
        # Run the embedding model on its own device/dtype, then hand the result back
        # in the decoder's device/dtype so it can be fed to the diffusion loops.
        with torch.no_grad():
            emb = loaded_model(image.to(device=model_param.device, dtype=model_param.dtype), instruction)
        return emb.to(device=img.device, dtype=img.dtype)

    # image_emb_main      -> embedding of the input image (+ `prompt`)
    # image_emb_alternate -> embedding of the input image + `alternate_prompt`
    # zero_image_emb_main -> embedding of a black image
    image_emb_main = _embed(img, prompt)
    image_emb_alternate = _embed(img, alternate_prompt)
    # `prepare_image` maps pixel value 0 to -1, so a black image is all -1 here.
    zero_image_emb_main = _embed(torch.full_like(img, -1.0), "")

    reversed_latents = pipe.forward_diffusion(
        "",
        image_embeds=image_emb_main,
        negative_image_embeds=zero_image_emb_main,
        guidance_scale=1,
        num_inference_steps=50,
        latents=image_latents,
        strength=0.5,
    )

    alternate_latents = pipe.backward_diffusion(
        "",
        image_embeds=image_emb_alternate,
        negative_image_embeds=zero_image_emb_main,
        guidance_scale=1,
        num_inference_steps=50,
        latents=reversed_latents,
        strength=0.5,
    )

    x = pipe.decode_image(alternate_latents)
    x = pipe.torch_to_numpy(x)
    return pipe.numpy_to_pil(x)[0]
    

In [13]:
# Apply a text instruction to example2.jpeg and write the edited image to disk.
impath = Path("../assets/test/example2.jpeg").expanduser()
edited_image = get_edited_image(img_path=impath, alternate_prompt="That should be a truck and not a car.")
edited_image.save('output1.png')

100%|██████████| 15/15 [00:09<00:00,  1.64it/s]
100%|██████████| 15/15 [00:03<00:00,  4.97it/s]
100%|██████████| 25/25 [00:25<00:00,  1.01s/it]
100%|██████████| 25/25 [00:11<00:00,  2.22it/s]


In [14]:
# Apply a text instruction to example1.jpeg and write the edited image to disk.
impath = Path("../assets/test/example1.jpeg").expanduser()
edited_image = get_edited_image(img_path=impath, alternate_prompt="Add mask to anyone of the player's face.")
edited_image.save('output2.png')

100%|██████████| 15/15 [00:10<00:00,  1.37it/s]
100%|██████████| 15/15 [00:03<00:00,  4.24it/s]
100%|██████████| 25/25 [00:24<00:00,  1.03it/s]
100%|██████████| 25/25 [00:11<00:00,  2.23it/s]


In [15]:
# Apply a text instruction to example.JPG and write the edited image to disk.
impath = Path("../assets/test/example.JPG").expanduser()
edited_image = get_edited_image(img_path=impath, alternate_prompt="Change the STOP sign to GO sign.")
edited_image.save('output3.png')

100%|██████████| 15/15 [00:10<00:00,  1.42it/s]
100%|██████████| 15/15 [00:03<00:00,  4.92it/s]
100%|██████████| 25/25 [00:28<00:00,  1.12s/it]
100%|██████████| 25/25 [00:11<00:00,  2.24it/s]
