In [1]:
import torch

# If there's a GPU available...
if torch.cuda.is_available():
    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")

    print("There are %d GPU(s) available." % torch.cuda.device_count())

    print("We will use the GPU:", torch.cuda.get_device_name(0))

# If not...
elif torch.backends.mps.is_available():
    device = torch.device("mps")

    print("Using mps backend")
else:
    print("No GPU available, using the CPU instead.")
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: NVIDIA A100-SXM4-80GB


In [2]:
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as nnf

from diffusers import DDIMScheduler, KandinskyPipeline

path = "kandinsky-community/kandinsky-2-2"

from PIL import Image
from torchvision import transforms

import inspect
from typing import Callable, List, Optional, Union

import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers.models import UNet2DConditionModel, VQModel
from diffusers.schedulers import DDIMScheduler, DDPMScheduler
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
# from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

from diffusers import KandinskyPriorPipeline, KandinskyPipeline, KandinskyV22PriorPipeline, KandinskyV22Pipeline
from diffusers.utils import load_image
import PIL
from torchvision import transforms
from diffusers.schedulers import DDIMScheduler,PNDMScheduler, LMSDiscreteScheduler
from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import randn_tensor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def backward_ddim(x_t, alpha_t: "alpha_t", alpha_tm1: "alpha_{t-1}", eps_xt):
    """ from noise to image"""
    return (
        alpha_tm1**0.5
        * (
            (alpha_t**-0.5 - alpha_tm1**-0.5) * x_t
            + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt
        )
        + x_t
    )

def forward_ddim(x_t, alpha_t: "alpha_t", alpha_tp1: "alpha_{t+1}", eps_xt):
    """ from image to noise, it's the same as backward_ddim"""
    return backward_ddim(x_t, alpha_t, alpha_tp1, eps_xt)

In [4]:

def get_new_h_w(h, w, scale_factor=8):
    new_h = h // scale_factor**2
    if h % scale_factor**2 != 0:
        new_h += 1
    new_w = w // scale_factor**2
    if w % scale_factor**2 != 0:
        new_w += 1
    return new_h * scale_factor, new_w * scale_factor


In [5]:
def prepare_image(pil_image, w=512, h=512):
    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float16) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = torch.from_numpy(arr).unsqueeze(0)
    return image


class NewKandinskyPipeline(KandinskyV22Pipeline):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        movq: VQModel,
    ):
        super().__init__(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)
    
    @torch.inference_mode()
    def get_image_latents(self, image, sample=False, rng_generator=None):
        encoding_dist = self.movq.encode(image).latents
        # if sample:
        #     encoding = encoding_dist.sample(generator=rng_generator)
        # else:
        #     encoding = encoding_dist.mode()
        latents = encoding_dist #* 0.18215
        return latents

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start


    @torch.no_grad()
    def backward_diffusion(
        self,
        prompt: Union[str, List[str]],
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
        reverse_process: bool = False,
        strength:float =0.3,
    ):

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt
        do_classifier_free_guidance = guidance_scale > 1.0

        # prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
        #     prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        # )

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        if do_classifier_free_guidance:
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
                 device=device, dtype=torch.float16
            )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps
        timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        num_channels_latents = self.unet.config.in_channels

        height, width = get_new_h_w(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            torch.float16,
            device,
            generator,
            latents,
            self.scheduler,
        )

        for i, t in enumerate(self.progress_bar(timesteps_tensor if not reverse_process else reversed(timesteps_tensor))):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            prev_timestep = (
                t
                - self.scheduler.config.num_train_timesteps
                // self.scheduler.num_inference_steps
            )
            # ddim 
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[prev_timestep]
                if prev_timestep >= 0
                else self.scheduler.final_alpha_cumprod
            )
            
            if reverse_process:
                alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
                latents = backward_ddim(
                    x_t=latents,
                    alpha_t=alpha_prod_t,
                    alpha_tm1=alpha_prod_t_prev,
                    eps_xt=noise_pred,
                )
            else:
                latents = self.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=generator,
                ).prev_sample
                

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)
        
        return latents 
        
    @torch.inference_mode()
    def decode_image(self, latents: torch.FloatTensor, **kwargs) -> List["PIL_IMAGE"]:
        return self.movq.decode(latents, force_not_quantize=True)["sample"]

    @torch.inference_mode()
    def torch_to_numpy(self, image) -> List["PIL_IMAGE"]:
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return image

In [6]:
from diffusers.models import PriorTransformer
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.schedulers import UnCLIPScheduler
from diffusers.utils import BaseOutput

class KandinskyPriorPipelineOutput(BaseOutput):
    """
    Output class for KandinskyPriorPipeline.

    Args:
        image_embeds (`torch.FloatTensor`)
            clip image embeddings for text prompt
        negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`)
            clip image embeddings for unconditional tokens
    """

    image_embeds: Union[torch.FloatTensor, np.ndarray]
    negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
    
class NewKandinskyPriorPipeline(KandinskyV22PriorPipeline):
    def __init__(
        self,
        prior: PriorTransformer,
        image_encoder: CLIPVisionModelWithProjection,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: UnCLIPScheduler,
        image_processor: CLIPImageProcessor,
    ):
        super().__init__(
            prior=prior,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )

    def _encode_image(
        self,
        image: Union[torch.Tensor, List[PIL.Image.Image]],
        device,
        num_images_per_prompt,
    ):
        if not isinstance(image, torch.Tensor):
            image = self.image_processor(image, return_tensors="pt").pixel_values.to(
                dtype=self.image_encoder.dtype, device=device
            )

        image_emb = self.image_encoder(image)["image_embeds"]  # B, D
        image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0)
        image_emb.to(device=device)

        return image_emb

    def prepare_latents_new(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        emb = emb.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        init_latents = emb

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    @torch.no_grad()
    def new_forward(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        guidance_scale: float = 4.0,
        output_type: Optional[str] = "pt",
        return_dict: bool = True,
        image_pil = None,
        strength: float = 0.3,
    ):

        if isinstance(prompt, str):
            prompt = [prompt]
        elif not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        elif not isinstance(negative_prompt, list) and negative_prompt is not None:
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        # if the negative prompt is defined we double the batch size to
        # directly retrieve the negative prompt embedding
        if negative_prompt is not None:
            prompt = prompt + negative_prompt
            negative_prompt = 2 * negative_prompt

        device = self._execution_device

        batch_size = len(prompt)
        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        image_embeds = self._encode_image(image_pil, device, num_images_per_prompt)

        # prior
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        prior_timesteps_tensor = self.scheduler.timesteps

        embedding_dim = self.prior.config.embedding_dim

        latents = image_embeds
        prior_timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = prior_timesteps_tensor[:1].repeat(batch_size)
        latents = self.prepare_latents_new(
            latents,
            latent_timestep,
            batch_size // num_images_per_prompt,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            latents = self.scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

        latents = self.prior.post_process_latents(latents)

        image_embeddings = latents

        # if negative prompt has been defined, we retrieve split the image embedding into two
        if negative_prompt is None:
            zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)

        else:
            image_embeddings, zero_embeds = image_embeddings.chunk(2)

            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.prior_hook.offload()

        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

        if output_type == "np":
            image_embeddings = image_embeddings.cpu().numpy()
            zero_embeds = zero_embeds.cpu().numpy()

        if not return_dict:
            return (image_embeddings, zero_embeds)

        return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)

In [7]:
pipe = NewKandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder",
                                            torch_dtype=torch.float16
                                           )
pipe.scheduler = DDIMScheduler.from_config("kandinsky-community/kandinsky-2-2-decoder", subfolder="scheduler")
pipe = pipe.to("cuda")

Downloading model_index.json: 100%|██████████████████████| 250/250 [00:00<00:00, 2.31MB/s]
Fetching 6 files:   0%|                                             | 0/6 [00:00<?, ?it/s]
Downloading unet/config.json: 100%|██████████████████| 1.67k/1.67k [00:00<00:00, 18.7MB/s][A

Downloading movq/config.json: 100%|██████████████████████| 660/660 [00:00<00:00, 8.54MB/s][A

Downloading (…)cheduler_config.json: 100%|███████████████| 317/317 [00:00<00:00, 3.97MB/s][A
Fetching 6 files:  33%|████████████▎                        | 2/6 [00:00<00:00,  6.94it/s]
Downloading (…)ch_model.safetensors:   0%|                     | 0.00/271M [00:00<?, ?B/s][A

Downloading (…)ch_model.safetensors:   0%|                    | 0.00/5.01G [00:00<?, ?B/s][A[A
Downloading (…)ch_model.safetensors:   4%|▍           | 10.5M/271M [00:00<00:03, 73.6MB/s][A

Downloading (…)ch_model.safetensors:   0%|            | 21.0M/5.01G [00:00<00:35, 142MB/s][A[A
Downloading (…)ch_model.safetensors:   8%|▉           | 21.

In [8]:
def load_img(path, target_size=512):
    """Load an image, resize and output -1..1"""
    image = Image.open(path).convert("RGB")

    image = prepare_image(image)
    return image

In [10]:
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, BertModel
import torch.nn as nn
from transformers import AdamW
import os


class T2IModel(nn.Module):
    def __init__(self):
        super(T2IModel, self).__init__()
        self.text_model = BertModel.from_pretrained('bert-base-uncased')
        self.vision_model = CLIPVisionModelWithProjection.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior", subfolder="image_encoder"
        )
        # Adjust the input features of the FC layer to the combined size of text and vision outputs
        self.fc = nn.Linear(self.text_model.config.hidden_size + self.vision_model.config.projection_dim, 1280)
        self.pipe = KandinskyV22Pipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        )

    def initialize_optimizer(self):
        params = (
            list(self.fc.parameters())
        )
        optimizer = AdamW(params, lr=1e-4)
        return optimizer

    def forward(self, input_imgs, input_txt, attention_mask=None):
        text_outputs = self.text_model(input_txt, attention_mask=attention_mask)
        text_embeds = text_outputs.last_hidden_state[:, 0, :]  # Use the representation of the [CLS] token

        vision_outputs = self.vision_model(input_imgs)
        vision_embeds = vision_outputs.image_embeds

        combined_embeds = torch.cat((vision_embeds, text_embeds), dim=1)
        x = self.fc(combined_embeds)
        return x

    def output_embedding(self, target_images):
        target_image_output = self.vision_model(target_images)
        target_image_embeds = target_image_output.image_embeds
        return target_image_embeds

    def custom_loss(self, output_embeddings, target_embeddings):
        mse_loss = nn.MSELoss()
        loss = mse_loss(output_embeddings, target_embeddings)

        return loss

    def save_model(self, output_dir="../model_save/", filename="model_checkpoint.pt"):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        file_path = output_dir + filename
        print("Saving model to %s" % file_path)

        torch.save(model.state_dict(), file_path)

    def get_cos(self, input1, input2):
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        similarity = cos(input1, input2)
        avg = torch.sum(similarity) / len(similarity)
        return avg

    def metrics(self, input1, input2):
        cos = self.get_cos(input1, input2)
        return [cos]

    def visualization(self, input_img, instruction, instruction_attention_mask, filename, negative_instruction=None, negative_instruction_attention_mask=None):
        # Generate output embeddings with the provided attention mask
        output_embeddings = self.forward(input_img, instruction, attention_mask=instruction_attention_mask)

        # Handle the negative instruction if provided
        neg_image_embed = None
        if negative_instruction is not None and negative_instruction_attention_mask is not None:
            neg_image_embed = self.forward(input_img, negative_instruction, attention_mask=negative_instruction_attention_mask)
        else:
            # If no negative instruction is provided, we'll use a tensor of zeros as a placeholder
            neg_image_embed = torch.zeros_like(output_embeddings)

        # Initialize the pipeline for the Kandinsky V2.2 decoder
        
        self.pipe.to(device)  # Make sure 'self.device' is defined in your model class

        # Generate the image using the pipeline
        image = self.pipe(
            image_embeds=output_embeddings,
            negative_image_embeds=neg_image_embed,
            height=768,
            width=768,
            num_inference_steps=100,
        ).images

        # Save the generated image
        image[0].save(filename)

model = T2IModel()
model.to(device=device)

Downloading config.json: 100%|███████████████████████████| 570/570 [00:00<00:00, 3.48MB/s]
Downloading model.safetensors: 100%|████████████████████| 440M/440M [00:01<00:00, 379MB/s]
Loading pipeline components...: 100%|███████████████████████| 3/3 [00:01<00:00,  2.51it/s]


T2IModel(
  (text_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [11]:
import torch

def load_model_from_checkpoint(model, checkpoint_path, device='cuda'):
    """
    Load a PyTorch model from a saved checkpoint.
    
    Parameters:
    - model (torch.nn.Module): The model architecture (untrained).
    - checkpoint_path (str): Path to the saved model checkpoint (.pth file).
    - device (str): Device to which the model should be loaded ('cuda' or 'cpu').

    Returns:
    - model (torch.nn.Module): Model populated with the loaded weights.
    """

    # Load the model state dictionary from the specified path
    state_dict = torch.load(checkpoint_path, map_location=device)
    
    # Load the state dictionary into the model
    model.load_state_dict(state_dict)
    
    # Move the model to the desired device
    model.to(device)
    
    return model

# Usage
loaded_model = load_model_from_checkpoint(model, '/scratch/nkusumba/magicbrush_kadinsky_bert_imagewithinstruction_10epochs_full_v1.pth', device='cuda')

In [12]:
from PIL import Image

def create_black_image(width, height):
    # Create a black image using PIL
    return Image.new("RGB", (width, height), (0, 0, 0))

def get_black_image():
    # Create and return a black image
    black_image = Image.new("RGB", (224, 224), (0, 0, 0))  # Adjust size as needed
    return black_image

In [14]:
from transformers import AutoTokenizer,AutoProcessor
from torchvision import transforms
def get_edited_image(img_path, tokenizer, processor, alternate_prompt):
    empty_prompt=""
    img = load_img(img_path, 512).to(device,
                                     dtype = torch.float16
                                    )
    image_latents = pipe.get_image_latents(img)
    
    # Convert image to tensor and ensure consistency in data type
    img = Image.open(img_path)
    input_image = processor(images=img, return_tensors="pt")["pixel_values"].to(device)
    black_image = processor(images=get_black_image(), return_tensors="pt")["pixel_values"].to(device)

    # Tokenize the prompts
    inputs = tokenizer(empty_prompt, return_tensors="pt").to(device)
    inputs_alternate = tokenizer(alternate_prompt, return_tensors="pt").to(device)

    image_emb_alternate = loaded_model(input_image, inputs_alternate["input_ids"]).to(dtype=torch.float16)
    zero_image_emb_alternate = loaded_model.vision_model(black_image).image_embeds.to(dtype=torch.float16)
    
    zero_image_emb_main = zero_image_emb_alternate
    image_emb_main = loaded_model.vision_model(input_image).image_embeds.to(dtype=torch.float16)

    
    reversed_latents = pipe.forward_diffusion(
        "",
        image_embeds=image_emb_main,
        negative_image_embeds=zero_image_emb_main,
        guidance_scale=1,
        num_inference_steps=100,
        latents=image_latents,
        strength=1.5,
    )

    alternate_latents = pipe.backward_diffusion(
        "",
        image_embeds=image_emb_alternate,
        negative_image_embeds=zero_image_emb_main,
        guidance_scale=1,
        num_inference_steps=100,
        latents=reversed_latents,
        strength=1.5,
    )

    x = pipe.decode_image(alternate_latents)
    x = pipe.torch_to_numpy(x)
    return pipe.numpy_to_pil(x)[0]
    

In [None]:
import os
import json

file_path = '/scratch/nkusumba/test/edit_sessions.json'

with open(file_path, 'r') as file:
    json_data = json.load(file)

# Printing each key and its corresponding value
dic = {}
for key, value in json_data.items():
    dic[key] = value[0]['instruction']

images_path = '/scratch/nkusumba/test/images/'
os.makedirs('/scratch/nkusumba/test/outputs/', exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")


print('Started Evaluation')
count = -1
dir_name = []
for dirpath, dirname, filenames in os.walk(images_path):
    if count == -1:
        count = 0
        dir_name = dirname
    if count == 100:
        print('Process done!!!')
        break
    input_path = ''
    output_path = ''
    for file in filenames:
        filepath = os.path.join(dirpath, file)
        if filepath.endswith('input.png'):
            input_path = filepath
        elif filepath.endswith('output1.png'):
            output_path = filepath
    if input_path == '':
        continue
    print(f'Processing {count+1}th image')
    dir = f'/scratch/nkusumba/test/outputs/{count+1}'
    os.makedirs(dir, exist_ok=True)
    out_img = Image.open(output_path)
    out_img.save(f'{dir}/groundtruth.png')

    # Process the image
    img = Image.open(input_path)
    img.save(f'{dir}/input_image.png')

    # Process the instruction
    instruction = dic[dir_name[count]]
    with open(f'{dir}/instruction.txt', 'w') as f:
        f.write(instruction)
        
    # Visualize the output
    out_img = get_edited_image(input_path, tokenizer, processor, instruction)
    out_img.save(f'{dir}/output.png')
    print(f'Finished processing {count+1}th image')
    count += 1


Started Evaluation
Processing 1th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.62it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 31.67it/s]


Finished processing 1th image
Processing 2th image


100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 31.27it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 33.17it/s]


Finished processing 2th image
Processing 3th image


100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 32.25it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 31.84it/s]


Finished processing 3th image
Processing 4th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.27it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.58it/s]


Finished processing 4th image
Processing 5th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.29it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.90it/s]


Finished processing 5th image
Processing 6th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.09it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.90it/s]


Finished processing 6th image
Processing 7th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.57it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.91it/s]


Finished processing 7th image
Processing 8th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.40it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.05it/s]


Finished processing 8th image
Processing 9th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.46it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.99it/s]


Finished processing 9th image
Processing 10th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.81it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.77it/s]


Finished processing 10th image
Processing 11th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.58it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 32.07it/s]


Finished processing 11th image
Processing 12th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.55it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.10it/s]


Finished processing 12th image
Processing 13th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.41it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.06it/s]


Finished processing 13th image
Processing 14th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.61it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.88it/s]


Finished processing 14th image
Processing 15th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.36it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 32.93it/s]


Finished processing 15th image
Processing 16th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.40it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.78it/s]


Finished processing 16th image
Processing 17th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.11it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 28.23it/s]


Finished processing 17th image
Processing 18th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.02it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.58it/s]


Finished processing 18th image
Processing 19th image


100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.04it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.84it/s]


Finished processing 19th image
Processing 20th image


100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.19it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 25.62it/s]


Finished processing 20th image
Processing 21th image


100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 28.70it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 27.06it/s]


Finished processing 21th image
Processing 22th image


100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 30.10it/s]
100%|███████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.83it/s]


Finished processing 22th image
Processing 23th image
